Quantile Regression Forests vs. Random Forests

An example comparing the estimates produced by a quantile regression forest (QRF) and a standard random forest (RF) regressor on a synthetic, right-skewed dataset. In a right-skewed distribution, the mean lies to the right of the median. As illustrated by the greater overlap between the histograms of the actual and predicted values, the median (quantile = 0.5) estimated by a quantile regressor can be a more reliable estimator of the center of a skewed distribution than the mean predicted by a standard random forest.

import altair as alt
import numpy as np
import pandas as pd
import scipy as sp
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.utils.validation import check_random_state

from quantile_forest import RandomForestQuantileRegressor

# Seeded RNG shared by the skew-normal sampler and the feature noise below.
# Both draw from the same stream, so the order of the draws matters for
# reproducibility.
rng = check_random_state(0)

# Create right-skewed dataset.
# Target: skew-normal with shape a=5, loc=-1, scale=1.
n_samples = 5000
a, loc, scale = 5, -1, 1
skewnorm_rv = sp.stats.skewnorm(a, loc, scale)
skewnorm_rv.random_state = rng
y = skewnorm_rv.rvs(n_samples)
# Two features: standard-normal noise multiplied row-wise by the target, so the
# spread of each row of X scales with |y|.
X = rng.randn(n_samples, 2) * y.reshape(-1, 1)

# Quantile grid 0.00, 0.01, ..., 1.00 (101 values) requested from the QRF.
quantiles = list(np.arange(101) / 100)

# Small (10-tree) forests with fixed seeds keep the example fast and repeatable.
regr_rf = RandomForestRegressor(n_estimators=10, random_state=0)
regr_qrf = RandomForestQuantileRegressor(n_estimators=10, random_state=0)

# Seeded train/test split.
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

regr_rf.fit(X_train, y_train)
regr_qrf.fit(X_train, y_train)

# y_pred_qrf is indexed later with [..., q_idx], i.e. its last axis is treated
# as the quantile axis (one column per entry of `quantiles`).
y_pred_rf = regr_rf.predict(X_test)  # standard RF predictions (mean)
y_pred_qrf = regr_qrf.predict(X_test, quantiles=quantiles)  # QRF predictions (quantiles)

# Bar colors keyed by series label. The keys should match the labels the chart
# folds the data into ("Actual", "RF (Mean)", "QRF (Quantile)"): the key order
# is used as the color sort order and the values, in order, as the category
# color range. The QRF series is labeled "Quantile" rather than "Median"
# because the slider can select any quantile, not just 0.5.
legend = {
    "Actual": "#c0c0c0",
    "RF (Mean)": "#f2a619",
    "QRF (Quantile)": "#006aff",
}

# Long-format plotting frame: one row per (test sample, quantile) pair holding
# the actual target, the RF mean prediction (the same for every quantile) and
# the QRF prediction at that quantile.
df = pd.concat(
    [
        pd.DataFrame(
            {
                "actual": y_test,
                "rf": y_pred_rf,
                "qrf": y_pred_qrf[..., q_idx],
                "quantile": quantile,  # scalar, broadcast to every row
            }
        )
        for q_idx, quantile in enumerate(quantiles)
    ],
    # Renumber rows 0..N-1; without this each test-sample index label would be
    # repeated once per quantile, giving a non-unique index.
    ignore_index=True,
)


def plot_prediction_histograms(df, legend):
    """Plot grouped histograms of actual values vs. RF and QRF predictions.

    For the quantile chosen with a slider, the chart shows bar counts of the
    actual test targets, the RF (mean) predictions and the QRF predictions at
    that quantile. Every value is rounded to one decimal place. Clicking a
    legend entry keeps that series colored and draws the others light gray.

    Parameters
    ----------
    df : pandas.DataFrame
        Long-format data with columns ``actual``, ``rf``, ``qrf`` and
        ``quantile``. The slider filters rows on ``quantile``.
    legend : dict
        Mapping of series label to color. Its key order is used as the color
        sort order, and its values (in order) as the category color range.

    Returns
    -------
    altair.Chart
        The interactive bar chart.
    """
    # Build the slider from the quantiles actually present in ``df`` rather
    # than from a module-level list, so the function depends only on its
    # arguments. The grid is assumed to be evenly spaced between its min and max.
    quantile_grid = np.sort(df["quantile"].unique())
    if len(quantile_grid) > 1:
        q_min = float(quantile_grid[0])
        q_max = float(quantile_grid[-1])
        q_step = (q_max - q_min) / (len(quantile_grid) - 1)
        # Start on the grid quantile closest to the median.
        q_default = float(quantile_grid[np.abs(quantile_grid - 0.5).argmin()])
    else:
        # Zero or one quantile: fall back to a unit-interval slider at 0.5.
        q_min, q_max, q_step, q_default = 0, 1, 0.5, 0.5

    slider = alt.binding_range(
        min=q_min,
        max=q_max,
        step=q_step,
        name="Quantile: ",
    )

    # Point selection on the ``quantile`` field, driven by the slider.
    q_val = alt.selection_point(
        value=q_default,
        bind=slider,
        fields=["quantile"],
    )

    # Clicking a legend entry selects that series label.
    click = alt.selection_point(fields=["label"], bind="legend")

    # Selected series keep their category color; the rest are drawn light gray.
    color = alt.condition(
        click,
        alt.Color("label:N", sort=list(legend.keys()), title=None),
        alt.value("lightgray"),
    )

    chart = (
        alt.Chart(df)
        .transform_filter(q_val)
        # Round each series to one decimal so the nominal x-axis has one
        # category per 0.1-wide value.
        .transform_calculate(calculate="round(datum.actual * 10) / 10", as_="Actual")
        .transform_calculate(calculate="round(datum.rf * 10) / 10", as_="RF (Mean)")
        .transform_calculate(calculate="round(datum.qrf * 10) / 10", as_="QRF (Quantile)")
        # Reshape to long form: one (label, value) row per series.
        .transform_fold(["Actual", "RF (Mean)", "QRF (Quantile)"], as_=["label", "value"])
        .mark_bar()
        .encode(
            x=alt.X(
                "value:N",
                axis=alt.Axis(
                    labelAngle=0,
                    # Only label multiples of 0.5 to keep the axis readable.
                    labelExpr="datum.value % 0.5 == 0 ? datum.value : null",
                ),
                title="Actual and Predicted Target Values",
            ),
            y=alt.Y("count():Q", axis=alt.Axis(format=",d", title="Counts")),
            color=color,
            # Place the series side by side within each x category.
            xOffset=alt.XOffset("label:N"),
            tooltip=[
                alt.Tooltip("label:N", title="Label"),
                alt.Tooltip("value:O", title="Value (binned)"),
                alt.Tooltip("count():Q", format=",d", title="Counts"),
            ],
        )
        .add_params(q_val, click)
        .configure_range(category=alt.RangeScheme(list(legend.values())))
        .properties(height=400, width=650)
    )
    return chart


# Build the interactive chart. It is left as the final bare expression,
# presumably so notebook / doc-gallery runners display it.
chart = plot_prediction_histograms(df=df, legend=legend)
chart